package com.lw.leetcode.tree.c;

import com.lw.test.util.Utils;

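/**
 * Created with IntelliJ IDEA.
 * LeetCode 327. Count of Range Sum.
 *
 * <p>Counts the contiguous sub-arrays whose element sum lies within an inclusive
 * {@code [lower, upper]} range, using prefix sums stored in a lazily allocated
 * segment tree keyed by value.
 *
 * @author liw
 * @version 1.0
 * @date 2022/6/8 16:55
 */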
public class CountRangeSum {

    /**
     * Ad-hoc driver: runs {@link #countRangeSum(int[], int, int)} on an array obtained from
     * {@code Utils.getArr} and prints the lower bound, the upper bound and the resulting count.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        CountRangeSum test = new CountRangeSum();

        // Reference cases with known answers, handy for manual verification:
        //   expected 3:
        //     int[] arr = {-2, 5, -1};  l = -2;  u = 2;
        //   expected 34:
        //     int[] arr = {-58, -52, -29, 55, -4, -79, 81, -43, 18, 42, -67, 85, -47, -78, -19, 55, 96, -97, -38, -28, 76, -7, -4, 52, 40, 41, 13, 68, -43, 97};
        //     l = -10;  u = 5;

        // NOTE(review): Utils is project-local; presumably this yields 10000 values spanning the
        // whole int range - confirm against Utils.getArr.
        int[] arr = Utils.getArr(10000, Integer.MIN_VALUE, Integer.MAX_VALUE);
        int l = -8000;
        int u = 15000;

        int i = test.countRangeSum(arr, l, u);
        System.out.println(l);
        System.out.println(u);
        System.out.println(i);
    }

    /**
     * Counts the contiguous, non-empty sub-arrays of {@code nums} whose sum {@code s} satisfies
     * {@code lower <= s <= upper}.
     *
     * <p>With prefix sums {@code P}, the sub-array ending at position {@code j} and starting after
     * prefix {@code i} qualifies when {@code P[j] - upper <= P[i] <= P[j] - lower}. Earlier prefix
     * sums are kept in a segment tree keyed by value, so each such interval is answered with two
     * "how many stored values are {@code <= x}" queries.
     *
     * <p>The tree covers exactly the span between the smallest and largest prefix sum of this
     * input (found in a first pass), so no assumption about the array length is baked into the
     * value range.
     *
     * @param nums  input values; {@code null} or empty yields 0
     * @param lower inclusive lower bound of the accepted sum
     * @param upper inclusive upper bound of the accepted sum; if smaller than {@code lower} the
     *              target range is empty and 0 is returned
     * @return the number of qualifying sub-arrays; accumulated in an {@code int}, so the true
     *         count is assumed to fit into 32 bits
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        if (nums == null || nums.length == 0 || lower > upper) {
            return 0;
        }

        // Pass 1: determine the value span of all prefix sums. The empty prefix (0) is part of it.
        long minPrefix = 0;
        long maxPrefix = 0;
        long sum = 0;
        for (int num : nums) {
            sum += num;
            minPrefix = Math.min(minPrefix, sum);
            maxPrefix = Math.max(maxPrefix, sum);
        }

        Node root = new Node(minPrefix, maxPrefix);
        add(root, 0);

        // Pass 2: for every prefix sum, count the earlier prefix sums inside the matching window.
        sum = 0;
        int count = 0;
        for (int num : nums) {
            sum += num;
            int atMostHigh = find(root, sum - lower);
            int belowLow = find(root, sum - upper - 1);
            count += atMostHigh - belowLow;
            add(root, sum);
        }
        return count;
    }

    /**
     * Returns how many values stored in the subtree rooted at {@code node} are {@code <= val}.
     *
     * @param node subtree root; {@code null} stands for a subtree that was never allocated
     * @param val  inclusive upper limit of the values to count; may lie outside the node's range
     * @return number of stored values not greater than {@code val}
     */
    private int find(Node node, long val) {
        // Nothing stored here, or the limit lies entirely below this node's range.
        if (node == null || val < node.st) {
            return 0;
        }
        // The whole range of this node (a leaf included) is covered by the limit.
        if (node.end <= val) {
            return node.count;
        }
        long m = node.st + ((node.end - node.st) >> 1);
        if (val <= m) {
            return find(node.left, val);
        }
        // val > m: everything in the left half qualifies, the right half only partially.
        int leftCount = node.left == null ? 0 : node.left.count;
        return leftCount + find(node.right, val);
    }

    /**
     * Records one occurrence of {@code val}, incrementing the counter of every node on the path
     * from {@code node} down to the leaf for {@code val} and allocating children on demand.
     *
     * @param node subtree root whose range is expected to contain {@code val}
     * @param val  value to record
     */
    private void add(Node node, long val) {
        Node cur = node;
        while (true) {
            cur.count++;
            if (cur.st == cur.end) {
                return;
            }
            long m = cur.st + ((cur.end - cur.st) >> 1);
            if (val <= m) {
                if (cur.left == null) {
                    cur.left = new Node(cur.st, m);
                }
                cur = cur.left;
            } else {
                if (cur.right == null) {
                    cur.right = new Node(m + 1, cur.end);
                }
                cur = cur.right;
            }
        }
    }

    /**
     * Segment-tree node covering the inclusive value range {@code [st, end]}. Children are created
     * only when a value falls into their half.
     */
    private static class Node {
        /** Inclusive lower bound of the covered values. */
        private final long st;
        /** Inclusive upper bound of the covered values. */
        private final long end;
        /** Number of recorded values that fall into {@code [st, end]}. */
        private int count;
        /** Child for the lower half {@code [st, m]}, or {@code null} if unused so far. */
        private Node left;
        /** Child for the upper half {@code [m + 1, end]}, or {@code null} if unused so far. */
        private Node right;

        private Node(long st, long end) {
            this.st = st;
            this.end = end;
        }
    }

}
